# MATHEMATICAL/STATISTICAL CONVENIENCE FUNCTIONS -------------------------------

#' Calculates the present-value factor of a stream of $1 per-period coupons.
#'
#' Evaluates the partial geometric sum sum_{k = 0}^{t_time - 1} (1 + r_rate)^(-k),
#' i.e. the first coupon is undiscounted. Both arguments are vectorized and
#' recycled against each other.
#' @param r_rate Numeric: discount / real interest rate (decimal)
#' @param t_time Integer: discrete-time length of NPV stream
#' @return Numeric: multiplicative NPV factor to convert $1 per-period coupon.
#'   When `r_rate` is exactly 0 the closed form is 0/0, so its limit `t_time`
#'   (one undiscounted dollar per period) is returned instead of NaN.
npv_factor <- function(r_rate, t_time) {
  discount <- 1 / (1 + r_rate)
  pv <- (1 - discount^t_time) / (1 - discount)
  # which() drops NA rates, so NA inputs still propagate as NA
  zero_rate <- which(rep_len(r_rate == 0, length(pv)))
  pv[zero_rate] <- rep_len(t_time, length(pv))[zero_rate]
  pv
}

# #' Calculates the annuity factor that translates a lump sum / stock of wealth
# #' into a per-period annual payment. The function requires setting the length
# #' over which the annuitization is taking place. Users can either supply this
# #' directly (using 't_time') or can supply both an age of annuity onset
# #' ('age_at_onset') AND life expectancy ('life_exp'), yielding a length equal to
# #' (life_exp - age_at_onset).
# #' @param r_rate Numeric: discount / real interest rate (decimal)
# #' @param age_at_onset Integer: age at which annuitization starts
# #' @param life_exp Integer: life expectancy / terminal age of annuity
# #' @param t_time Integer: discrete-time length of annuity stream (default is NA)
# #' @return Numeric: multiplicative factor that converts wealth to a per-period payment
# per_period_annuity_factor <-
#   function(r_rate,
#            age_at_onset,
#            life_exp,
#            t_time = as.integer(NA)) {
#     if (is.na(t_time)) {
#       (r_rate / (1 + r_rate)) *
#         (1 / (1 - (1 / (1 + r_rate))^(life_exp - age_at_onset + 1)))
#     } else {
#       (r_rate / (1 + r_rate)) *
#         (1 / (1 - (1 / (1 + r_rate))^(t_time + 1)))
#     }
# }

#' Calculates a weighted covariance matrix admitting frequency weights; agrees
#' with wtd.var() from the Hmisc package.
#' TODO: the original source of this implementation should be credited.
#' @param x Numeric: vector, matrix, or data frame of values (rows = observations)
#' @param frqwt Numeric: frequency weight for each row of `x`; default is equal.
#' @param unbiased Logical: TRUE (default) divides by sum(frqwt) - 1 (dof
#'   correction); FALSE divides by sum(frqwt).
#' @return List: weighted covariance (`cov`), weighted column means (`center`),
#'   and the `unbiased` flag used.
wtd_cov <- function(x, frqwt = rep(1, nrow(x)), unbiased = TRUE) {
  # Coerce data frames and plain vectors to a matrix *before* `frqwt` is first
  # used: its default `rep(1, nrow(x))` is evaluated lazily, so it sees the
  # coerced `x` (a plain vector would otherwise give nrow() == NULL).
  if (is.data.frame(x) || is.null(dim(x))) {
    x <- as.matrix(x)
  }
  if (length(frqwt) != nrow(x)) {
    stop(sprintf("'frqwt' has length %d, but 'x' has %d rows.",
                 length(frqwt), nrow(x)), call. = FALSE)
  }
  n <- sum(frqwt)
  # frqwt has length nrow(x), so `*` scales each row (column-major recycling)
  center <- colSums(frqwt * x) / n
  xcw <- sqrt(frqwt) * sweep(x, 2, center, check.margin = TRUE)
  denom <- if (unbiased) n - 1 else n
  cov_mat <- crossprod(xcw) / denom
  list(cov = cov_mat, center = center, unbiased = unbiased)
}

#' Delta-method standard errors (or variances) of functions of estimates.
#'
#' Near-verbatim from the 'msm' package. Each formula in `g` is written in terms
#' of x1, x2, ..., xn, where xi stands for `mean[i]`:
#'   Var(g(X)) ~= D_g %*% Var(X) %*% t(D_g),
#' with D_g the Jacobian of `g` evaluated at `mean`.
#' @param g Formula, or list of formulas, e.g. `~ x1 / x2`.
#' @param mean Numeric: point estimates at which the Jacobian is evaluated.
#' @param cov Matrix: n x n variance-covariance matrix, n = length(mean).
#' @param ses Logical: TRUE (default) returns standard errors; FALSE returns
#'   variances (the diagonal of the transformed covariance matrix).
#' @return Numeric vector with one entry per formula in `g`.
delta_method <- function(g, mean, cov, ses = TRUE) {
  cov <- as.matrix(cov)
  n <- length(mean)
  if (!is.list(g)) {
    g <- list(g)
  }
  if ((dim(cov)[1] != n) || (dim(cov)[2] != n)) {
    stop(sprintf("'cov' is a %s X %s matrix, but should be a square %s X %s matrix.",
                 dim(cov)[1], dim(cov)[2], n, n), call. = FALSE)
  }
  syms <- paste0("x", seq_len(n))
  # Bind x1..xn in this frame so eval(deriv(...)) below can find them
  for (i in seq_len(n)) assign(syms[i], mean[i])
  # Jacobian: one row per formula, one column per parameter.
  # rbind() keeps this orientation even when n == 1 (t(sapply()) did not).
  D_g <- do.call(rbind, lapply(g, function(form) {
    as.numeric(attr(eval(deriv(form, syms)), "gradient"))
  }))
  new_cov <- D_g %*% cov %*% t(D_g)
  if (isTRUE(ses)) {
    sqrt(diag(new_cov))
  } else {
    diag(new_cov)
  }
}

makeRatioVcov <- function(coefs_fs, coefs_rf, vcov_fs, vcov_rf){

  # Builds the variance-covariance matrix of the element-wise ratio estimates
  # beta_ts2sls = coefs_rf / coefs_fs (reduced form over first stage).
  #
  # As we don't have cross-equation correlations, we take the approach as in, for example:
  ## Dee and Evans (2003), Pacini and Windmeijer (2016), and Sampat and Williams (2018)
  ## i.e., treating the cross-equation covariance of residuals as 0.
  # Only the diagonals of vcov_fs / vcov_rf are used, so the returned matrix is
  # diagonal: covariances between different ratios are set to 0.
  #
  # Variance of each ratio (first-order Taylor expansion / delta method):
  #   (vcov_rf[i, i] + (coefs_rf[i] / coefs_fs[i])^2 * vcov_fs[i, i]) / coefs_fs[i]^2
  # This is the diagonal of eqn 12 in Pacini and Windmeijer (2016):
  #   C_hat Var_pi_y1 C_hat' + (beta_ts2sls' %x% C_hat) Var_pi_x2 (beta_ts2sls' %x% C_hat)'
  # with C_hat = diag(1 / coefs_fs). Every matrix in that expression is diagonal,
  # so the quadratic forms reduce to element-wise arithmetic (no solve/crossprod).
  #
  # Returns: k x k diagonal matrix (no dimnames), k = length(coefs_fs).

  k <- length(coefs_fs)
  if (length(coefs_rf) != k) {
    stop("'coefs_rf' and 'coefs_fs' must have the same length.", call. = FALSE)
  }

  # use notation to match Pacini and Windmeijer (2016): diagonals of Var_pi_y1 / Var_pi_x2
  var_pi_y1 <- diag(as.matrix(vcov_rf))
  var_pi_x2 <- diag(as.matrix(vcov_fs))
  if (length(var_pi_y1) != k || length(var_pi_x2) != k) {
    stop("'vcov_rf' and 'vcov_fs' must be square matrices matching the coefficient length.",
         call. = FALSE)
  }

  # A zero first-stage coefficient makes C_hat singular (the ratio is undefined)
  if (any(coefs_fs == 0, na.rm = TRUE)) {
    stop("'coefs_fs' contains a zero first-stage coefficient; the ratio and its variance are undefined.",
         call. = FALSE)
  }

  beta_ts2sls <- coefs_rf / coefs_fs
  ratio_var <- (var_pi_y1 + beta_ts2sls^2 * var_pi_x2) / coefs_fs^2

  diag(x = as.numeric(ratio_var), nrow = k)
}

#' Recomputes ratio-model estimates and standard errors as means of
#' cohort-specific ratios.
#'
#' For the "ratio" model this function:
#' (1) replaces the cohort-specific (CATT) ratio standard errors with
#'     delta-method SEs from makeRatioVcov();
#' (2) optionally flags CATT ratio outliers with an IQR rule by event time;
#' (3) replaces the aggregate ATT ratio estimates (ratio of means) with
#'     equally- and cohort-weighted means of the non-outlier CATT ratios, per
#'     event time and, when requested, per collapse grouping;
#' (4) fills in delta-method SEs for those weighted averages via delta_method().
#'
#' @param results_object List: element 1 is the estimates data.table; elements 4
#'   and 5 are `list(coefs, vcov)` for the reduced form and first stage
#'   (presumably the output of Wald_ES(); TODO confirm).
#' @param omitted_event_time Integer: reference event time(s) with no estimate.
#' @param calculate_collapse_estimates Logical: recompute the "Collapsed" rows?
#' @param homogeneous_ATT Logical: collapsed rows are only recomputed when FALSE.
#' @param anticipation Integer: the default pre-period collapse windows end at
#'   `-1 - anticipation`.
#' @param model_type Character: suffix selecting `catt_coefs_<model_type>` and
#'   `catt_vcov_<model_type>` for the SE calculations (presumably "ratio").
#' @param trim Logical: flag CATT ratio outliers and exclude them from averages.
#' @param trim_iqr_factor Numeric: IQR multiplier for the outlier fences.
#' @param use_custom_collapse_table Logical: use `collapse_table_dt` instead of
#'   the built-in collapse groupings.
#' @param collapse_table_dt data.table: two columns, group name and a list
#'   column of event times (only used if `use_custom_collapse_table`).
#' @return data.table: the updated estimates, in the original row order.
addRatioSEs <- function(results_object,
                        omitted_event_time,
                        calculate_collapse_estimates,
                        homogeneous_ATT,
                        anticipation,
                        model_type,
                        trim = FALSE,
                        trim_iqr_factor = 1.5,
                        use_custom_collapse_table = FALSE,
                        collapse_table_dt = NULL){

  # Load required packages
  # 'results_object' should be a data.table, so hopefully already loaded
  if( !("package:data.table" %in% search()) ){
    library(data.table, verbose = TRUE)
  }

  # Work on a copy: the `:=` calls below would otherwise add helper columns
  # (file_order, outlier) to the caller's results_object[[1]] by reference.
  figdata <- copy(results_object[[1]])

  # R oddity -- sometimes 'grouping' column in figdata is logical rather than character
  # want it to be character
  if(("grouping" %in% colnames(figdata)) && !is.character(figdata$grouping)){
    figdata[, grouping := as.character(grouping)]
  }

  # from Wald_ES internals
  # if (homogeneous_ATT == FALSE) {
  #   return_list[[4]] <- list(catt_coefs_rf, catt_vcov_rf)
  #   if (!is.null(endog_var)) {
  #     return_list[[5]] <- list(catt_coefs_fs, catt_vcov_fs)
  #   }
  # }
  catt_coefs_rf <- results_object[[4]][[1]]
  catt_coefs_fs <- results_object[[5]][[1]]
  catt_vcov_rf <- results_object[[4]][[2]]
  catt_vcov_fs <- results_object[[5]][[2]]

  # Collapsed estimates are only (re)computed for heterogeneous ATTs
  do_collapse <- calculate_collapse_estimates == TRUE && homogeneous_ATT == FALSE

  figdata[, file_order := .I] # just to fix the order in case any operations below alter it
  figdata[, outlier := 0L] # will only be relevant for the 'trim' == TRUE case

  if(use_custom_collapse_table == TRUE){
    collapse_table <- copy(collapse_table_dt)
  } else{
    if("first_two" %in% unique(figdata$grouping)){
      collapse_table <- data.table(a = c("pre",
                                         "on_impact",
                                         "first_two",
                                         "three_to_five",
                                         "post_avg"),
                                   b = c(list(-7:-1),
                                         list(1),
                                         list(1:2),
                                         list(3:5),
                                         list(1:5)))
    } else{
      collapse_table <- data.table(a = c("pre_7",
                                         "pre_5",
                                         "one_to_two",
                                         "one_to_three",
                                         "one_to_four",
                                         "post_avg",
                                         "two_to_three",
                                         "two_to_four",
                                         "two_to_five",
                                         "three_to_four",
                                         "three_to_five",
                                         "four_to_five"),
                                   b = c(list(-7:(-1 - anticipation)),
                                         list(-5:(-1 - anticipation)),
                                         list(1:2),
                                         list(1:3),
                                         list(1:4),
                                         list(1:5),
                                         list(2:3),
                                         list(2:4),
                                         list(2:5),
                                         list(3:4),
                                         list(3:5),
                                         list(4:5)))
    }
  }

  collapse_inputs <- copy(collapse_table)

  ## Calculate cohort-specific vcov matrix for the ratio estimates
  # => then, calculate SEs for cohort-specific ratios
  catt_vcov_ratio <- makeRatioVcov(coefs_fs = copy(catt_coefs_fs),
                                   coefs_rf = copy(catt_coefs_rf),
                                   vcov_fs = copy(catt_vcov_fs),
                                   vcov_rf = copy(catt_vcov_rf))
  catt_coefs_ratio <- catt_coefs_rf/catt_coefs_fs
  catt_ses_ratio <- sqrt(diag(catt_vcov_ratio))
  names(catt_ses_ratio) <- names(catt_coefs_rf)

  event_times <- setdiff(figdata[, sort(unique(ref_event_time))], omitted_event_time)
  onset_times <- as.integer(figdata[rn == "catt", sort(unique(ref_onset_time))])

  min_onset_time <- min(onset_times)
  max_onset_time <- max(onset_times)

  catt_ses_ratio <- as.data.table(catt_ses_ratio, keep.rownames = TRUE)
  setnames(catt_ses_ratio, c("catt_ses_ratio"), c("cluster_se"))

  # copied below from main Wald_ES() function, with slight alteration
  # (parses coefficient names into onset-time / event-time columns)
  catt_ses_ratio[!grepl("ref\\_onset\\_time", rn), e := min_onset_time]
  catt_ses_ratio[, rn := gsub("lead", "-", rn)]
  for (c in min_onset_time:max_onset_time) {
    catt_ses_ratio[grepl(sprintf("ref\\_onset\\_time%s", c), rn), e := c]
    catt_ses_ratio[grepl(sprintf("ref\\_onset\\_time%s", c), rn), rn := gsub(sprintf("ref\\_onset\\_time%s\\_et", c), "et", rn)]
    catt_ses_ratio[grepl(sprintf("ref\\_onset\\_time%s", c), rn), rn := gsub(sprintf("ref\\_onset\\_time%s\\_catt", c), "catt", rn)]
  }
  rm(c)
  catt_ses_ratio[grepl("et", rn), event_time := as.integer(gsub("et", "", rn))]
  catt_ses_ratio[grepl("catt", rn), event_time := as.integer(gsub("catt", "", rn))]
  catt_ses_ratio[grepl("et", rn), rn := "event_time"]
  catt_ses_ratio[grepl("catt", rn), rn := "catt"]
  setnames(catt_ses_ratio, c("e","event_time"), c("ref_onset_time","ref_event_time"))
  catt_ses_ratio[, model := "ratio"]
  catt_ses_ratio[, rn := NULL]
  catt_ses_ratio[, ref_onset_time := as.character(ref_onset_time)] # to match 'figdata'

  figdata_catt_ratio <- figdata[rn == "catt" & model == "ratio"]
  figdata_att_ratio <- figdata[rn == "att" & model == "ratio" & !grepl("Collapsed", ref_onset_time)]
  figdata_collapsed_ratio <- figdata[rn == "att" & model == "ratio" & grepl("Collapsed", ref_onset_time)]
  figdata <- figdata[!(rn %in% c("catt","att") & (model == "ratio"))]

  figdata_catt_ratio[, cluster_se := NULL]
  figdata_catt_ratio <- merge(figdata_catt_ratio, catt_ses_ratio, by = c("ref_onset_time","ref_event_time","model"), sort = FALSE)
  rm(catt_ses_ratio)

  if(trim == TRUE){
    # Tukey-style fences within each event time: Q1 - k*IQR, Q3 + k*IQR
    figdata_catt_ratio[, upper_bound := quantile(estimate, probs = 0.75) + (trim_iqr_factor*IQR(estimate)), by = list(ref_event_time, model)]
    figdata_catt_ratio[, lower_bound := quantile(estimate, probs = 0.25) - (trim_iqr_factor*IQR(estimate)), by = list(ref_event_time, model)]
    figdata_catt_ratio[, outlier := as.integer(estimate > upper_bound | estimate < lower_bound)]
    figdata_catt_ratio[, c("upper_bound", "lower_bound") := NULL]
  }

  figdata <- rbindlist(list(figdata, figdata_catt_ratio), use.names = TRUE)

  # Replace the current rn == "att" & model == "ratio" values with mean of ratios (rather than ratio of means)
  # cohort-weighted (omitting outliers if relevant)
  ew <- figdata_catt_ratio[outlier == 0,
                           list(ref_onset_time = "Equally-Weighted",
                                estimate = mean(estimate)),
                           by = list(ref_event_time, model)
                           ]
  cw <- figdata_catt_ratio[outlier == 0,
                           list(ref_onset_time = "Cohort-Weighted",
                                estimate = weighted.mean(estimate,
                                                         w = (catt_treated_unique_units / sum(catt_treated_unique_units)))),
                           by = list(ref_event_time, model)
                           ]
  cw2 <- figdata_catt_ratio[outlier == 0,
                            list(ref_onset_time = "Cohort-Weighted V2",
                                 estimate = weighted.mean(estimate,
                                                          w = (catt_total_unique_units / sum(catt_total_unique_units)))),
                            by = list(ref_event_time, model)
                            ]

  event_time_means_of_ratio <- rbindlist(list(ew, cw, cw2), use.names = TRUE)
  rm(ew, cw, cw2)

  figdata_att_ratio[, estimate := NA]
  figdata_att_ratio[event_time_means_of_ratio, on = list(ref_onset_time, ref_event_time, model), estimate := i.estimate]
  rm(event_time_means_of_ratio)

  # collapsed
  if (do_collapse) {

    collapse_input_dt <- copy(collapse_inputs)
    setnames(collapse_input_dt, c("name", "event_times"))

    ew <- list()
    cw <- list()
    cw2 <- list()
    j <- 0
    for (g in unique(na.omit(collapse_input_dt[["name"]]))) {
      j <- j + 1
      group_event_times <- setdiff(unique(na.omit(unlist(collapse_input_dt[name ==
                                                                             g][[2]]))), omitted_event_time)
      ddt <- figdata_catt_ratio[(ref_event_time %in% group_event_times) & outlier == 0]

      # R oddity -- sometimes 'grouping' column in ddt is logical rather than character
      # since local 'g' is character, change type below to avoid unnecessary warnings
      if(!is.character(ddt$grouping)){
        ddt[, grouping := as.character(grouping)]
      }

      ddt[, `:=`(grouping, g)]
      ddt[, `:=`(unweighted_estimate, mean(estimate)),
          by = list(ref_event_time, model)]
      ddt[, `:=`(weighted_estimate_V1, weighted.mean(estimate,
                                                     w = (catt_treated_unique_units / sum(catt_treated_unique_units)))),
          by = list(ref_event_time, model)]
      ddt[, `:=`(weighted_estimate_V2, weighted.mean(estimate,
                                                     w = (catt_total_unique_units / sum(catt_total_unique_units)))),
          by = list(ref_event_time, model)]
      ddt[, `:=`(rowid, seq_len(.N)), by = list(ref_event_time,
                                                model)]
      result <- ddt[rowid == 1 | is.na(rowid)]

      ddt <- ddt[, list(ref_event_time, ref_onset_time, model, grouping)]

      ew[[j]] <- result[, list(grouping, model, unweighted_estimate)]
      ew[[j]][, `:=`(unweighted_estimate, mean(unweighted_estimate)), by = list(grouping, model)]
      ew[[j]][, `:=`(ref_onset_time, "Equally-Weighted + Collapsed")]
      ew[[j]][, `:=`(rn, "att")]
      ew[[j]][, `:=`(rowid, seq_len(.N)), by = list(model)]
      setnames(ew[[j]], c("unweighted_estimate"), c("estimate"))
      ew[[j]] <- ew[[j]][rowid == 1 | is.na(rowid)]
      ew[[j]][, `:=`(rowid, NULL)]
      ew[[j]][, `:=`(cluster_se, 0)]

      cw[[j]] <- result[, list(grouping, model, weighted_estimate_V1)]
      cw[[j]][, `:=`(weighted_estimate_V1, mean(weighted_estimate_V1)), by = list(grouping, model)]
      cw[[j]][, `:=`(ref_onset_time, "Cohort-Weighted + Collapsed")]
      cw[[j]][, `:=`(rn, "att")]
      cw[[j]][, `:=`(rowid, seq_len(.N)), by = list(model)]
      setnames(cw[[j]], c("weighted_estimate_V1"), c("estimate"))
      cw[[j]] <- cw[[j]][rowid == 1 | is.na(rowid)]
      cw[[j]][, `:=`(rowid, NULL)]
      cw[[j]][, `:=`(cluster_se, 0)]

      cw2[[j]] <- result[, list(grouping, model, weighted_estimate_V2)]
      cw2[[j]][, `:=`(weighted_estimate_V2, mean(weighted_estimate_V2)), by = list(grouping, model)]
      cw2[[j]][, `:=`(ref_onset_time, "Cohort-Weighted V2 + Collapsed")]
      cw2[[j]][, `:=`(rn, "att")]
      cw2[[j]][, `:=`(rowid, seq_len(.N)), by = list(model)]
      setnames(cw2[[j]], c("weighted_estimate_V2"), c("estimate"))
      cw2[[j]] <- cw2[[j]][rowid == 1 | is.na(rowid)]
      cw2[[j]][, `:=`(rowid, NULL)]
      cw2[[j]][, `:=`(cluster_se, 0)]
      rm(result)
      rm(ddt)
    }
    rm(g, j, group_event_times)
    ew <- rbindlist(ew, use.names = TRUE)
    cw <- rbindlist(cw, use.names = TRUE)
    cw2 <- rbindlist(cw2, use.names = TRUE)

    collapsed_means_of_ratio <- rbindlist(list(ew, cw, cw2), use.names = TRUE)
    rm(ew, cw, cw2)

    # Only overwrite the collapsed ratio estimates when they were recomputed;
    # previously this ran unconditionally and failed with
    # "object 'collapsed_means_of_ratio' not found" when collapsing was off.
    figdata_collapsed_ratio[, estimate := NA]
    figdata_collapsed_ratio[collapsed_means_of_ratio, on = list(ref_onset_time, grouping, model), estimate := i.estimate]
    rm(collapsed_means_of_ratio)
  }

  non_outlier_table <- copy(figdata_catt_ratio[, list(ref_onset_time, ref_event_time, outlier)])
  rm(figdata_catt_ratio)

  figdata <- rbindlist(list(figdata, figdata_att_ratio, figdata_collapsed_ratio), use.names = TRUE)
  rm(figdata_att_ratio, figdata_collapsed_ratio)
  setorderv(figdata, "file_order")

  # calculate SEs for weighted avg estimates using catt_coefs and catt_vcov
  # the unique "Weighted" ref_event_time values represent the target list of parameters
  # for each such ref_event_time, want to extract the location (number) of the relevant parameters in catt_coefs
  # then will need to grab the relevant weight, and then construct the formula to supply to delta_method()
  # ==> copied (with modifications) from Wald_ES()
  # ==> outliers of figdata_catt_ratio are excluded via 'non_outlier_table'
  # NOTE(review): SEs are only written into rows whose cluster_se is currently 0;
  # presumably the incoming ratio ATT rows carry 0 there -- confirm against Wald_ES().

  for(et in event_times){

    if(et < 0){
      lookfor <- sprintf("cattlead%s$", abs(et))
      # crucial to have the end-of-line anchor "$" above; otherwise will find, e.g.,  -1 and -19:-10 event times
    } else{
      lookfor <- sprintf("catt%s$", abs(et))
      # crucial to have the end-of-line anchor "$" above; otherwise will find, e.g.,  1 and 10:19 event times
    }
    coef_indices <- grep(lookfor, names(get(sprintf("catt_coefs_%s", model_type))))
    rm(lookfor)
    temp <- as.data.table(do.call(cbind, list(get(sprintf("catt_coefs_%s", model_type))[coef_indices], coef_indices)), keep.rownames = TRUE)
    setnames(temp, c("V1", "V2"), c("estimate", "coef_index"))
    rm(coef_indices)
    temp[, estimate := NULL]
    temp[, rn := gsub("lead", "-", rn)]
    for (c in min_onset_time:max_onset_time) {
      temp[grepl(sprintf("ref\\_onset\\_time%s", c), rn), ref_onset_time := c]
      temp[grepl(sprintf("ref\\_onset\\_time%s", c), rn), rn := gsub(sprintf("ref\\_onset\\_time%s\\_catt", c), "catt", rn)]
    }
    temp[grepl("catt", rn), ref_event_time := as.integer(gsub("catt", "", rn))]
    temp[, rn := NULL]
    temp[, ref_onset_time := as.character(ref_onset_time)]

    # restrict to non-outliers
    # -- only restrictive if 'trim' == TRUE
    temp <- merge(temp, non_outlier_table, by = c("ref_onset_time", "ref_event_time"), all.x = TRUE, sort = FALSE)
    temp <- temp[outlier == 0]

    # now merge in the weights
    temp <- merge(temp, figdata[rn == "catt" & model == model_type], by = c("ref_onset_time", "ref_event_time"), all.x = TRUE, sort = FALSE)
    temp <- temp[, list(ref_onset_time, ref_event_time, coef_index, catt_treated_unique_units, catt_total_unique_units)]
    temp[, equal_weight := 1 / .N]
    temp[, cohort_weight_V1 := catt_treated_unique_units / sum(catt_treated_unique_units)]
    temp[, cohort_weight_V2 := catt_total_unique_units / sum(catt_total_unique_units)]

    # linear combination sum_i w_i * x<coef_index>, in delta_method()'s x1..xn notation
    temp[, equal_w_formula_entry := sprintf("(%s*x%s)", equal_weight, coef_index)]
    temp[, cohort_w_v1_formula_entry := sprintf("(%s*x%s)", cohort_weight_V1, coef_index)]
    temp[, cohort_w_v2_formula_entry := sprintf("(%s*x%s)", cohort_weight_V2, coef_index)]

    equal_w_g_formula_input <- paste0(temp$equal_w_formula_entry, collapse = "+")
    cohort_w_v1_g_formula_input <- paste0(temp$cohort_w_v1_formula_entry, collapse = "+")
    cohort_w_v2_g_formula_input <- paste0(temp$cohort_w_v2_formula_entry, collapse = "+")

    figdata[rn == "att" & model == model_type & cluster_se == 0 & ref_event_time == et & ref_onset_time == "Equally-Weighted",
            cluster_se := delta_method(g = as.formula(paste("~", equal_w_g_formula_input)),
                                       mean = get(sprintf("catt_coefs_%s", model_type)),
                                       cov = get(sprintf("catt_vcov_%s", model_type)),
                                       ses = TRUE
            )
            ]

    figdata[rn == "att" & model == model_type & cluster_se == 0 & ref_event_time == et & ref_onset_time == "Cohort-Weighted",
            cluster_se := delta_method(g = as.formula(paste("~", cohort_w_v1_g_formula_input)),
                                       mean = get(sprintf("catt_coefs_%s", model_type)),
                                       cov = get(sprintf("catt_vcov_%s", model_type)),
                                       ses = TRUE
            )
            ]

    figdata[rn == "att" & model == model_type & cluster_se == 0 & ref_event_time == et & ref_onset_time == "Cohort-Weighted V2",
            cluster_se := delta_method(g = as.formula(paste("~", cohort_w_v2_g_formula_input)),
                                       mean = get(sprintf("catt_coefs_%s", model_type)),
                                       cov = get(sprintf("catt_vcov_%s", model_type)),
                                       ses = TRUE
            )
            ]

    rm(temp, equal_w_g_formula_input, cohort_w_v1_g_formula_input, cohort_w_v2_g_formula_input)

  }
  rm(et)
  gc()

  # Now we calculate the collapsed estimates, if relevant
  if(do_collapse){

    collapse_input_dt <- copy(collapse_inputs)
    setnames(collapse_input_dt, c("name", "event_times"))

    for(g in unique(na.omit(collapse_input_dt[["name"]]))){

      # extract event_times and results corresponding to grouping
      # as we won't have an estimate for the omitted_event_time, exclude it below
      group_event_times <- setdiff(unique(na.omit(unlist(collapse_input_dt[name == g][[2]]))), omitted_event_time)
      ddt <- figdata[(ref_event_time %in% group_event_times) & (rn == "catt") & outlier == 0]
      ddt[, grouping := g]
      ddt[, rowid := seq_len(.N), by = list(ref_event_time, model)]
      ddt <- ddt[, list(ref_event_time, ref_onset_time, model, catt_treated_unique_units, catt_total_unique_units, grouping)]

      templist <- list()
      i <- 0
      for(et in group_event_times){

        i <- i + 1

        if(et < 0){
          lookfor <- sprintf("cattlead%s$", abs(et))
          # crucial to have the end-of-line anchor "$" above; otherwise will find, e.g.,  -1 and -19:-10 event times
        } else{
          lookfor <- sprintf("catt%s$", abs(et))
          # crucial to have the end-of-line anchor "$" above; otherwise will find, e.g.,  1 and 10:19 event times
        }
        coef_indices <- grep(lookfor, names(get(sprintf("catt_coefs_%s", model_type))))
        rm(lookfor)
        temp <- as.data.table(do.call(cbind, list(get(sprintf("catt_coefs_%s", model_type))[coef_indices], coef_indices)), keep.rownames = TRUE)
        setnames(temp, c("V1", "V2"), c("estimate", "coef_index"))
        rm(coef_indices)
        temp[, estimate := NULL]
        temp[, rn := gsub("lead", "-", rn)]
        for (c in min_onset_time:max_onset_time) {
          temp[grepl(sprintf("ref\\_onset\\_time%s", c), rn), ref_onset_time := c]
          temp[grepl(sprintf("ref\\_onset\\_time%s", c), rn), rn := gsub(sprintf("ref\\_onset\\_time%s\\_catt", c), "catt", rn)]
        }
        rm(c)
        temp[grepl("catt", rn), ref_event_time := as.integer(gsub("catt", "", rn))]
        temp[, rn := NULL]
        temp[, ref_onset_time := as.character(ref_onset_time)]

        # restrict to non-outliers
        # -- only restrictive if 'trim' == TRUE
        temp <- merge(temp, non_outlier_table, by = c("ref_onset_time", "ref_event_time"), all.x = TRUE, sort = FALSE)
        temp <- temp[outlier == 0]

        # now merge in the within-event-time weights
        temp <- merge(temp, ddt[model == model_type], by = c("ref_onset_time", "ref_event_time"), all.x = TRUE, sort = FALSE)
        temp <- temp[, list(ref_onset_time, ref_event_time, coef_index, catt_treated_unique_units, catt_total_unique_units)]
        temp[, weight_V0 := 1 / .N]
        temp[, cohort_weight_V1 := catt_treated_unique_units / sum(catt_treated_unique_units)]
        temp[, cohort_weight_V2 := catt_total_unique_units / sum(catt_total_unique_units)]

        templist[[i]] <- copy(temp)
        rm(temp)
        gc()

      }
      rm(i, et, group_event_times)

      templist <- rbindlist(templist, use.names = TRUE)

      # Now add the across-event-time weights and calculate full (multiplicative) weights
      templist[, across_weight := (1 / uniqueN(ref_event_time))]
      templist[, full_weight_V0 := weight_V0 * across_weight]
      templist[, full_weight_V1 := cohort_weight_V1 * across_weight]
      templist[, full_weight_V2 := cohort_weight_V2 * across_weight]

      templist[, equal_w_formula_entry := sprintf("(%s*x%s)", full_weight_V0, coef_index)]
      templist[, cohort_w_v1_formula_entry := sprintf("(%s*x%s)", full_weight_V1, coef_index)]
      templist[, cohort_w_v2_formula_entry := sprintf("(%s*x%s)", full_weight_V2, coef_index)]

      formula_input_ew <- paste0(templist$equal_w_formula_entry, collapse = "+")
      formula_input_cw <- paste0(templist$cohort_w_v1_formula_entry, collapse = "+")
      formula_input_cw2 <- paste0(templist$cohort_w_v2_formula_entry, collapse = "+")

      rm(templist)

      figdata[grouping == g & model == model_type & cluster_se == 0 & ref_onset_time == "Equally-Weighted + Collapsed",
              cluster_se := delta_method(g = as.formula(paste("~", formula_input_ew)),
                                         mean = get(sprintf("catt_coefs_%s", model_type)),
                                         cov = get(sprintf("catt_vcov_%s", model_type)),
                                         ses = TRUE
              )
              ]

      figdata[grouping == g & model == model_type & cluster_se == 0 & ref_onset_time == "Cohort-Weighted + Collapsed",
              cluster_se := delta_method(g = as.formula(paste("~", formula_input_cw)),
                                         mean = get(sprintf("catt_coefs_%s", model_type)),
                                         cov = get(sprintf("catt_vcov_%s", model_type)),
                                         ses = TRUE
              )
              ]

      figdata[grouping == g & model == model_type & cluster_se == 0 & ref_onset_time == "Cohort-Weighted V2 + Collapsed",
              cluster_se := delta_method(g = as.formula(paste("~", formula_input_cw2)),
                                         mean = get(sprintf("catt_coefs_%s", model_type)),
                                         cov = get(sprintf("catt_vcov_%s", model_type)),
                                         ses = TRUE
              )
              ]

      rm(formula_input_ew)
      rm(formula_input_cw)
      rm(formula_input_cw2)
      rm(ddt)
    }
    rm(g)
  }

  figdata[, outlier := NULL]
  setorderv(figdata, "file_order")
  figdata[, file_order := NULL]
  return(figdata)
}

addRatioSEs2 <- function(results_object,
                        omitted_event_time,
                        calculate_collapse_estimates,
                        homogeneous_ATT,
                        anticipation,
                        model_type,
                        trim = FALSE,
                        trim_iqr_factor = 1.5,
                        use_custom_collapse_table = FALSE,
                        collapse_table_dt = NULL){

  # results_object = copy(r_results);
  # omitted_event_time = omitted_event_time_update;
  # calculate_collapse_estimates = TRUE;
  # homogeneous_ATT = FALSE;
  # anticipation = anticipation;
  # model_type = "ratio"
  # trim = FALSE
  # trim_iqr_factor = 1.5

  # Load require packages
  # 'results_object' should be a data.table, so hopefuly already loaded
  if( !("package:data.table" %in% search()) ){
    library(data.table, verbose = TRUE)
  }

  figdata <- results_object[[1]]

  # R oddity -- sometimes 'grouping' column in figdata is logical rather than character
  # want it to be character
  if(("grouping" %in% colnames(figdata)) & (class(figdata$grouping) != "character")){
    figdata[, grouping := as.character(grouping)]
  }

  if ( !("grouping" %in% colnames(figdata))) {
        figdata[, grouping := as.character(NA)]
    }

  # from Wald_ES internals
  # if (homogeneous_ATT == FALSE) {
  #   return_list[[4]] <- list(catt_coefs_rf, catt_vcov_rf)
  #   if (!is.null(endog_var)) {
  #     return_list[[5]] <- list(catt_coefs_fs, catt_vcov_fs)
  #   }
  # }
  catt_coefs_rf <- results_object[[4]][[1]]
  catt_coefs_fs <- results_object[[5]][[1]]
  catt_vcov_rf <- results_object[[4]][[2]]
  catt_vcov_fs <- results_object[[5]][[2]]

  figdata[, file_order := .I] # just to fix the order in case any operations below alter it
  figdata[, outlier := 0L] # will only be relevant for the 'trim' == TRUE case

  if(use_custom_collapse_table == TRUE){
    collapse_table <- copy(collapse_table_dt)
  } else{
    if("first_two" %in% unique(figdata$grouping)){
      collapse_table <- data.table(a = c("pre",
                                         "on_impact",
                                         "first_two",
                                         "three_to_five",
                                         "post_avg"),
                                   b = c(list(-7:-1),
                                         list(1),
                                         list(1:2),
                                         list(3:5),
                                         list(1:5)))
    } else{
      collapse_table <- data.table(a = c("pre_7",
                                         "pre_5",
                                         "one_to_two",
                                         "one_to_three",
                                         "one_to_four",
                                         "post_avg",
                                         "two_to_three",
                                         "two_to_four",
                                         "two_to_five",
                                         "three_to_four",
                                         "three_to_five",
                                         "four_to_five"),
                                   b = c(list(-7:(-1 - anticipation)),
                                         list(-5:(-1 - anticipation)),
                                         list(1:2),
                                         list(1:3),
                                         list(1:4),
                                         list(1:5),
                                         list(2:3),
                                         list(2:4),
                                         list(2:5),
                                         list(3:4),
                                         list(3:5),
                                         list(4:5)))
    }
  }

  collapse_inputs <- copy(collapse_table)

  ## Calculate cohort-specific vcov matrix for the ratio estimates
  # => then, calculate SEs for cohort-specific ratios
  catt_vcov_ratio <- makeRatioVcov(coefs_fs = copy(catt_coefs_fs),
                                   coefs_rf = copy(catt_coefs_rf),
                                   vcov_fs = copy(catt_vcov_fs),
                                   vcov_rf = copy(catt_vcov_rf))
  catt_coefs_ratio <- catt_coefs_rf/catt_coefs_fs
  catt_ses_ratio <- sqrt(diag(catt_vcov_ratio))
  names(catt_ses_ratio) <- names(catt_coefs_rf)

  event_times <- setdiff(figdata[, sort(unique(ref_event_time))], omitted_event_time)
  onset_times <- as.integer(figdata[rn == "catt", sort(unique(ref_onset_time))])

  min_onset_time <- min(onset_times)
  max_onset_time <- max(onset_times)

  catt_ses_ratio <- as.data.table(catt_ses_ratio, keep.rownames = TRUE)
  setnames(catt_ses_ratio, c("catt_ses_ratio"), c("cluster_se"))

  # copied below from main Wald_ES() function, with slight alteration
  catt_ses_ratio[!grepl("ref\\_onset\\_time", rn), e := min_onset_time]
  catt_ses_ratio[, rn := gsub("lead", "-", rn)]
  for (c in min_onset_time:max_onset_time) {
    catt_ses_ratio[grepl(sprintf("ref\\_onset\\_time%s", c), rn), e := c]
    catt_ses_ratio[grepl(sprintf("ref\\_onset\\_time%s", c), rn), rn := gsub(sprintf("ref\\_onset\\_time%s\\_et", c), "et", rn)]
    catt_ses_ratio[grepl(sprintf("ref\\_onset\\_time%s", c), rn), rn := gsub(sprintf("ref\\_onset\\_time%s\\_catt", c), "catt", rn)]
  }
  rm(c)
  catt_ses_ratio[grepl("et", rn), event_time := as.integer(gsub("et", "", rn))]
  catt_ses_ratio[grepl("catt", rn), event_time := as.integer(gsub("catt", "", rn))]
  catt_ses_ratio[grepl("et", rn), rn := "event_time"]
  catt_ses_ratio[grepl("catt", rn), rn := "catt"]
  setnames(catt_ses_ratio, c("e","event_time"), c("ref_onset_time","ref_event_time"))
  catt_ses_ratio[, model := "ratio"]
  catt_ses_ratio[, rn := NULL]
  catt_ses_ratio[, ref_onset_time := as.character(ref_onset_time)] # to match 'figdata'

  figdata_catt_ratio <- figdata[rn == "catt" & model == "ratio"]
  figdata <- figdata[!(rn %in% c("catt","att") & (model == "ratio"))]

  figdata_catt_ratio[, cluster_se := NULL]
  figdata_catt_ratio <- merge(figdata_catt_ratio, catt_ses_ratio, by = c("ref_onset_time","ref_event_time","model"), sort = FALSE)
  rm(catt_ses_ratio)

  if(trim == TRUE){
    figdata_catt_ratio[, upper_bound := quantile(estimate, probs = 0.75) + (trim_iqr_factor*IQR(estimate)), by = list(ref_event_time, model)]
    figdata_catt_ratio[, lower_bound := quantile(estimate, probs = 0.25) - (trim_iqr_factor*IQR(estimate)), by = list(ref_event_time, model)]
    figdata_catt_ratio[, outlier := as.integer(estimate > upper_bound | estimate < lower_bound)]
    figdata_catt_ratio[, c("upper_bound", "lower_bound") := NULL]
  }

  figdata <- rbindlist(list(figdata, figdata_catt_ratio), use.names = TRUE)

  # Replace the rn == "att" & model == "ratio" rows (dropped above) with means of the cohort-specific
  # ratios (rather than ratios of means): equally-weighted, cohort-weighted, and cohort-weighted V2
  # (omitting outliers if relevant)
  ew <- figdata_catt_ratio[outlier == 0,
                           list(ref_onset_time = "Equally-Weighted",
                                estimate = mean(estimate)),
                           by = list(ref_event_time, model)
                           ]
  cw <- figdata_catt_ratio[outlier == 0,
                           list(ref_onset_time = "Cohort-Weighted",
                                estimate = weighted.mean(estimate,
                                                         w = (catt_treated_unique_units / sum(catt_treated_unique_units)))),
                           by = list(ref_event_time, model)
                           ]
  cw2 <- figdata_catt_ratio[outlier == 0,
                            list(ref_onset_time = "Cohort-Weighted V2",
                                 estimate = weighted.mean(estimate,
                                                          w = (catt_total_unique_units / sum(catt_total_unique_units)))),
                            by = list(ref_event_time, model)
                            ]

  figdata_att_ratio <- rbindlist(list(ew, cw, cw2), use.names = TRUE)
  rm(ew, cw, cw2)

  # collapsed
  if (calculate_collapse_estimates == TRUE & homogeneous_ATT == FALSE) {

    collapse_input_dt <- copy(collapse_inputs)
    setnames(collapse_input_dt, c("name", "event_times"))

    ew <- list()
    cw <- list()
    cw2 <- list()
    j <- 0
    for (g in unique(na.omit(collapse_input_dt[["name"]]))) {
      j <- j + 1
      group_event_times <- setdiff(unique(na.omit(unlist(collapse_input_dt[name ==
                                                                             g][[2]]))), omitted_event_time)
      ddt <- figdata_catt_ratio[(ref_event_time %in% group_event_times) & outlier == 0]

      # R oddity -- sometimes 'grouping' column in ddt is logical rather than character
      # since local 'g' is character, change type below to avoid unnecessary warnings
      if(class(ddt$grouping) != "character"){
        ddt[, grouping := as.character(grouping)]
      }

      ddt[, `:=`(grouping, g)]
      ddt[, `:=`(unweighted_estimate, mean(estimate)),
          by = list(ref_event_time, model)]
      ddt[, `:=`(weighted_estimate_V1, weighted.mean(estimate,
                                                     w = (catt_treated_unique_units / sum(catt_treated_unique_units)))),
          by = list(ref_event_time, model)]
      ddt[, `:=`(weighted_estimate_V2, weighted.mean(estimate,
                                                     w = (catt_total_unique_units / sum(catt_total_unique_units)))),
          by = list(ref_event_time, model)]
      ddt[, `:=`(rowid, seq_len(.N)), by = list(ref_event_time,
                                                model)]
      result <- ddt[rowid == 1 | is.na(rowid)]

      ddt <- ddt[, list(ref_event_time, ref_onset_time, model, grouping)]

      ew[[j]] <- result[, list(grouping, model, unweighted_estimate)]
      ew[[j]][, `:=`(unweighted_estimate, mean(unweighted_estimate)), by = list(grouping, model)]
      ew[[j]][, `:=`(ref_onset_time, "Equally-Weighted + Collapsed")]
      ew[[j]][, `:=`(rn, "att")]
      ew[[j]][, `:=`(rowid, seq_len(.N)), by = list(model)]
      setnames(ew[[j]], c("unweighted_estimate"), c("estimate"))
      ew[[j]] <- ew[[j]][rowid == 1 | is.na(rowid)]
      ew[[j]][, `:=`(rowid, NULL)]
      ew[[j]][, `:=`(cluster_se, 0)]

      cw[[j]] <- result[, list(grouping, model, weighted_estimate_V1)]
      cw[[j]][, `:=`(weighted_estimate_V1, mean(weighted_estimate_V1)), by = list(grouping, model)]
      cw[[j]][, `:=`(ref_onset_time, "Cohort-Weighted + Collapsed")]
      cw[[j]][, `:=`(rn, "att")]
      cw[[j]][, `:=`(rowid, seq_len(.N)), by = list(model)]
      setnames(cw[[j]], c("weighted_estimate_V1"), c("estimate"))
      cw[[j]] <- cw[[j]][rowid == 1 | is.na(rowid)]
      cw[[j]][, `:=`(rowid, NULL)]
      cw[[j]][, `:=`(cluster_se, 0)]

      cw2[[j]] <- result[, list(grouping, model, weighted_estimate_V2)]
      cw2[[j]][, `:=`(weighted_estimate_V2, mean(weighted_estimate_V2)), by = list(grouping, model)]
      cw2[[j]][, `:=`(ref_onset_time, "Cohort-Weighted V2 + Collapsed")]
      cw2[[j]][, `:=`(rn, "att")]
      cw2[[j]][, `:=`(rowid, seq_len(.N)), by = list(model)]
      setnames(cw2[[j]], c("weighted_estimate_V2"), c("estimate"))
      cw2[[j]] <- cw2[[j]][rowid == 1 | is.na(rowid)]
      cw2[[j]][, `:=`(rowid, NULL)]
      cw2[[j]][, `:=`(cluster_se, 0)]
      rm(result)
      rm(ddt)
    }
    rm(g, j, group_event_times)
    ew <- rbindlist(ew, use.names = TRUE)
    cw <- rbindlist(cw, use.names = TRUE)
    cw2 <- rbindlist(cw2, use.names = TRUE)

    figdata_collapsed_ratio <- rbindlist(list(ew, cw, cw2), use.names = TRUE)
    rm(ew, cw, cw2)
  }

  non_outlier_table <- copy(figdata_catt_ratio[, list(ref_onset_time, ref_event_time, outlier)])
  rm(figdata_catt_ratio)

  figdata <- rbindlist(list(figdata, figdata_att_ratio, figdata_collapsed_ratio), use.names = TRUE, fill = TRUE)
  rm(figdata_att_ratio, figdata_collapsed_ratio)
  setorderv(figdata, "file_order")

  # calculate SEs for weighted avg estimates using catt_coefs and catt_vcov
  # the unique "Weighted" ref_event_time values represent the target list of parameters
  # for each such ref_event_time, want to extract the location (number) of the relevant parameters in catt_coefs
  # then will need to grab the relevant weight, and then construct the formula to supply to delta_method()
  # ==> copied (with modifications) from Wald_ES()
  # ==> outliers of figdata_catt_ratio are excluded below by merging with 'non_outlier_table'

  for(et in event_times){

    if(et < 0){
      lookfor <- sprintf("cattlead%s$", abs(et))
      # crucial to have the end-of-line anchor "$" above; otherwise will find, e.g.,  -1 and -19:-10 event times
    } else{
      lookfor <- sprintf("catt%s$", abs(et))
      # crucial to have the end-of-line anchor "$" above; otherwise will find, e.g.,  1 and 10:19 event times
    }
    coef_indices <- grep(lookfor, names(get(sprintf("catt_coefs_%s", model_type))))
    rm(lookfor)
    temp <- as.data.table(do.call(cbind, list(get(sprintf("catt_coefs_%s", model_type))[coef_indices], coef_indices)), keep.rownames = TRUE)
    setnames(temp, c("V1", "V2"), c("estimate", "coef_index"))
    rm(coef_indices)
    temp[, estimate := NULL]
    temp[, rn := gsub("lead", "-", rn)]
    for (c in min_onset_time:max_onset_time) {
      temp[grepl(sprintf("ref\\_onset\\_time%s", c), rn), ref_onset_time := c]
      temp[grepl(sprintf("ref\\_onset\\_time%s", c), rn), rn := gsub(sprintf("ref\\_onset\\_time%s\\_catt", c), "catt", rn)]
    }
    temp[grepl("catt", rn), ref_event_time := as.integer(gsub("catt", "", rn))]
    temp[, rn := NULL]
    temp[, ref_onset_time := as.character(ref_onset_time)]

    # restrict to non-outliers
    # -- only restrictive if 'trim' == TRUE
    temp <- merge(temp, non_outlier_table, by = c("ref_onset_time", "ref_event_time"), all.x = TRUE, sort = FALSE)
    temp <- temp[outlier == 0]

    # now merge in the weights
    temp <- merge(temp, figdata[rn == "catt" & model == model_type], by = c("ref_onset_time", "ref_event_time"), all.x = TRUE, sort = FALSE)
    temp <- temp[, list(ref_onset_time, ref_event_time, coef_index, catt_treated_unique_units, catt_total_unique_units)]
    temp[, equal_weight := 1 / .N]
    temp[, cohort_weight_V1 := catt_treated_unique_units / sum(catt_treated_unique_units)]
    temp[, cohort_weight_V2 := catt_total_unique_units / sum(catt_total_unique_units)]

    temp[, equal_w_formula_entry := sprintf("(%s*x%s)", equal_weight, coef_index)]
    temp[, cohort_w_v1_formula_entry := sprintf("(%s*x%s)", cohort_weight_V1, coef_index)]
    temp[, cohort_w_v2_formula_entry := sprintf("(%s*x%s)", cohort_weight_V2, coef_index)]

    equal_w_g_formula_input = paste0(temp$equal_w_formula_entry, collapse = "+")
    cohort_w_v1_g_formula_input = paste0(temp$cohort_w_v1_formula_entry, collapse = "+")
    cohort_w_v2_g_formula_input = paste0(temp$cohort_w_v2_formula_entry, collapse = "+")

    figdata[rn == "att" & model == model_type & cluster_se == 0 & ref_event_time == et & ref_onset_time == "Equally-Weighted",
            cluster_se := delta_method(g = as.formula(paste("~", equal_w_g_formula_input)),
                                       mean = get(sprintf("catt_coefs_%s", model_type)),
                                       cov = get(sprintf("catt_vcov_%s", model_type)),
                                       ses = TRUE
            )
            ]

    figdata[rn == "att" & model == model_type & cluster_se == 0 & ref_event_time == et & ref_onset_time == "Cohort-Weighted",
            cluster_se := delta_method(g = as.formula(paste("~", cohort_w_v1_g_formula_input)),
                                       mean = get(sprintf("catt_coefs_%s", model_type)),
                                       cov = get(sprintf("catt_vcov_%s", model_type)),
                                       ses = TRUE
            )
            ]

    figdata[rn == "att" & model == model_type & cluster_se == 0 & ref_event_time == et & ref_onset_time == "Cohort-Weighted V2",
            cluster_se := delta_method(g = as.formula(paste("~", cohort_w_v2_g_formula_input)),
                                       mean = get(sprintf("catt_coefs_%s", model_type)),
                                       cov = get(sprintf("catt_vcov_%s", model_type)),
                                       ses = TRUE
            )
            ]

    rm(temp, equal_w_g_formula_input, cohort_w_v1_g_formula_input, cohort_w_v2_g_formula_input)

  }
  rm(et)
  gc()

  # Now we calculate the collapsed estimates, if relevant
  if(calculate_collapse_estimates == TRUE & homogeneous_ATT == FALSE){

    collapse_input_dt <- copy(collapse_inputs)
    setnames(collapse_input_dt, c("name", "event_times"))

    for(g in unique(na.omit(collapse_input_dt[["name"]]))){

      # extract event_times and results corresponding to grouping
      # as we won't have an estimate for the omitted_event_time, exclude it below
      group_event_times <- setdiff(unique(na.omit(unlist(collapse_input_dt[name == g][[2]]))), omitted_event_time)
      ddt <- figdata[(ref_event_time %in% group_event_times) & (rn == "catt") & outlier == 0]
      ddt[, grouping := g]
      ddt[, rowid := seq_len(.N), by = list(ref_event_time, model)]
      ddt <- ddt[, list(ref_event_time, ref_onset_time, model, catt_treated_unique_units, catt_total_unique_units, grouping)]

      templist = list()
      i = 0
      for(et in group_event_times){

        i = i + 1

        if(et < 0){
          lookfor <- sprintf("cattlead%s$", abs(et))
          # crucial to have the end-of-line anchor "$" above; otherwise will find, e.g.,  -1 and -19:-10 event times
        } else{
          lookfor <- sprintf("catt%s$", abs(et))
          # crucial to have the end-of-line anchor "$" above; otherwise will find, e.g.,  1 and 10:19 event times
        }
        coef_indices <- grep(lookfor, names(get(sprintf("catt_coefs_%s", model_type))))
        rm(lookfor)
        temp <- as.data.table(do.call(cbind, list(get(sprintf("catt_coefs_%s", model_type))[coef_indices], coef_indices)), keep.rownames = TRUE)
        setnames(temp, c("V1", "V2"), c("estimate", "coef_index"))
        rm(coef_indices)
        temp[, estimate := NULL]
        temp[, rn := gsub("lead", "-", rn)]
        for (c in min_onset_time:max_onset_time) {
          temp[grepl(sprintf("ref\\_onset\\_time%s", c), rn), ref_onset_time := c]
          temp[grepl(sprintf("ref\\_onset\\_time%s", c), rn), rn := gsub(sprintf("ref\\_onset\\_time%s\\_catt", c), "catt", rn)]
        }
        rm(c)
        temp[grepl("catt", rn), ref_event_time := as.integer(gsub("catt", "", rn))]
        temp[, rn := NULL]
        temp[, ref_onset_time := as.character(ref_onset_time)]

        # restrict to non-outliers
        # -- only restrictive if 'trim' == TRUE
        temp <- merge(temp, non_outlier_table, by = c("ref_onset_time", "ref_event_time"), all.x = TRUE, sort = FALSE)
        temp <- temp[outlier == 0]

        # now merge in the within-event-time weights
        temp <- merge(temp, ddt[model == model_type], by = c("ref_onset_time", "ref_event_time"), all.x = TRUE, sort = FALSE)
        temp <- temp[, list(ref_onset_time, ref_event_time, coef_index, catt_treated_unique_units, catt_total_unique_units)]
        temp[, weight_V0 := 1 / .N]
        temp[, cohort_weight_V1 := catt_treated_unique_units / sum(catt_treated_unique_units)]
        temp[, cohort_weight_V2 := catt_total_unique_units / sum(catt_total_unique_units)]

        templist[[i]] <- copy(temp)
        rm(temp)
        gc()

      }
      rm(i, et, group_event_times)

      templist <- rbindlist(templist, use.names = TRUE)

      # Now add the across-event-time weights and calculate full (multiplicative) weights
      templist[, across_weight := (1 / uniqueN(ref_event_time))]
      templist[, full_weight_V0 := weight_V0 * across_weight]
      templist[, full_weight_V1 := cohort_weight_V1 * across_weight]
      templist[, full_weight_V2 := cohort_weight_V2 * across_weight]

      templist[, equal_w_formula_entry := sprintf("(%s*x%s)", full_weight_V0, coef_index)]
      templist[, cohort_w_v1_formula_entry := sprintf("(%s*x%s)", full_weight_V1, coef_index)]
      templist[, cohort_w_v2_formula_entry := sprintf("(%s*x%s)", full_weight_V2, coef_index)]

      formula_input_ew = paste0(templist$equal_w_formula_entry, collapse = "+")
      formula_input_cw = paste0(templist$cohort_w_v1_formula_entry, collapse = "+")
      formula_input_cw2 = paste0(templist$cohort_w_v2_formula_entry, collapse = "+")

      rm(templist)

      figdata[grouping == g & model == model_type & cluster_se == 0 & ref_onset_time == "Equally-Weighted + Collapsed",
              cluster_se := delta_method(g = as.formula(paste("~", formula_input_ew)),
                                         mean = get(sprintf("catt_coefs_%s", model_type)),
                                         cov = get(sprintf("catt_vcov_%s", model_type)),
                                         ses = TRUE
              )
              ]

      figdata[grouping == g & model == model_type & cluster_se == 0 & ref_onset_time == "Cohort-Weighted + Collapsed",
              cluster_se := delta_method(g = as.formula(paste("~", formula_input_cw)),
                                         mean = get(sprintf("catt_coefs_%s", model_type)),
                                         cov = get(sprintf("catt_vcov_%s", model_type)),
                                         ses = TRUE
              )
              ]

      figdata[grouping == g & model == model_type & cluster_se == 0 & ref_onset_time == "Cohort-Weighted V2 + Collapsed",
              cluster_se := delta_method(g = as.formula(paste("~", formula_input_cw2)),
                                         mean = get(sprintf("catt_coefs_%s", model_type)),
                                         cov = get(sprintf("catt_vcov_%s", model_type)),
                                         ses = TRUE
              )
              ]

      rm(formula_input_ew)
      rm(formula_input_cw)
      rm(formula_input_cw2)
      rm(ddt)
    }
    rm(g)
  }

  figdata[, outlier := NULL]
  setorderv(figdata, "file_order")
  figdata[, file_order := NULL]
  return(figdata)
}